from agent.common.network import CommonNet, MLP, DuelingNet


def cnn_factory(*args, **kargs):
    """Create a ``CommonNet`` from the given arguments.

    All positional and keyword arguments are forwarded to ``CommonNet``
    unchanged; no ``linear_only`` value is injected here, so the
    constructor's own default applies unless the caller passes one.

    Returns:
        The newly constructed ``CommonNet`` instance.
    """
    network = CommonNet(*args, **kargs)
    return network


def mlp_factory(*args, **kargs):
    """Create a ``CommonNet`` with ``linear_only=True``.

    All positional and keyword arguments are forwarded to ``CommonNet``.
    Callers must not pass ``linear_only`` themselves: this factory always
    supplies it, and a duplicate keyword raises ``TypeError``.

    Returns:
        The newly constructed ``CommonNet`` instance.
    """
    # NOTE(review): linear_only presumably disables the convolutional part
    # of CommonNet -- confirm against agent.common.network.
    return CommonNet(*args, linear_only=True, **kargs)


def dueling_factory(A_kargs, V_kargs, net_type='cnn'):
    """Return a factory that builds ``DuelingNet`` instances.

    Args:
        A_kargs: Mapping of extra keyword arguments for the ``MLP`` whose
            output size is the base network's ``action_dim`` (presumably
            the advantage head -- confirm against ``DuelingNet``).
        V_kargs: Mapping of extra keyword arguments for the ``MLP`` whose
            output size is 1 (presumably the state-value head).
        net_type: ``'cnn'`` builds the base ``CommonNet`` with
            ``linear_only=False``; any other value builds it with
            ``linear_only=True``.

    Returns:
        A callable that forwards its positional and keyword arguments to
        ``CommonNet`` and wraps the resulting network, together with the
        two ``MLP`` heads, in a ``DuelingNet``.
    """
    # Every value other than the exact string 'cnn' selects linear_only.
    use_linear_only = net_type != 'cnn'

    def _dueling_factory(*net_args, **net_kwargs):
        backbone = CommonNet(
            *net_args, linear_only=use_linear_only, **net_kwargs)
        # Both heads take the backbone's output_dim as their input size.
        a_head = MLP(backbone.output_dim, backbone.action_dim, **A_kargs)
        v_head = MLP(backbone.output_dim, 1, **V_kargs)
        return DuelingNet(backbone, a_head, v_head)

    return _dueling_factory
